package testutil

import (
	"fmt"
	"strconv"
	"strings"

	v1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/api/resource"

	"github.com/linkerd/linkerd2/pkg/k8s"
)

const (
	// enabled is the string form of "on" that the validator writes into
	// boolean annotations and expects in boolean proxy env vars.
	enabled = "true"

	// proxyContainerName is the name of the proxy sidecar looked up among
	// the pod's regular containers.
	proxyContainerName = "linkerd-proxy"

	// initContainerName is the name looked up among the pod's init
	// containers.
	initContainerName = "linkerd-init"

	// debugContainerName is the name of the debug sidecar looked up among
	// the pod's regular containers.
	debugContainerName = "linkerd-debug"
)

// InjectValidator describes a set of expected proxy-injection settings.
// The same value serves two purposes: GetFlagsAndAnnotations turns the
// settings into injector CLI flags and pod annotations, and ValidatePod
// checks an injected pod spec against them.
//
// Fields left at their zero value (0, "" or false) are not emitted as flags
// or annotations, and ValidatePod is intended to check only the settings
// that were given a value.
type InjectValidator struct {
	// NoInitContainer makes validation skip the init container entirely.
	NoInitContainer          bool
	DisableIdentity          bool
	// AutoInject adds the proxy-inject annotation; it has no flag form here.
	AutoInject               bool
	AdminPort                int
	ControlPort              int
	DisableTap               bool
	// EnableDebug requires the debug sidecar container to be present.
	EnableDebug              bool
	EnableExternalProfiles   bool
	ImagePullPolicy          string
	InboundPort              int
	// InitImage and InitImageVersion are compared against the two halves of
	// the init container's "image:version" string.
	InitImage                string
	InitImageVersion         string
	OutboundPort             int
	// CPULimit, CPURequest, MemoryLimit and MemoryRequest are resource
	// quantity strings parsed with resource.MustParse during validation.
	CPULimit                 string
	CPURequest               string
	MemoryLimit              string
	MemoryRequest            string
	// Image and Version are compared against the two halves of the proxy
	// container's "image:version" string.
	Image                    string
	LogLevel                 string
	LogFormat                string
	UID                      int
	Version                  string
	RequireIdentityOnPorts   string
	SkipOutboundPorts        string
	OpaquePorts              string
	// SkipInboundPorts is expected on the init container prefixed with
	// ControlPort and AdminPort, as "<control>,<admin>,<SkipInboundPorts>".
	SkipInboundPorts         string
	TraceCollector           string
	TraceCollectorSvcAccount string
	OutboundConnectTimeout   string
	InboundConnectTimeout    string
	// WaitBeforeExitSeconds is the sleep duration expected in the proxy
	// container's pre-stop exec command.
	WaitBeforeExitSeconds    int
}

// getContainer returns the container called name from the pod spec. It
// searches pod.InitContainers when isInit is true and pod.Containers
// otherwise, and returns nil when no container has that name.
//
// The returned pointer refers to the element stored in the pod spec's slice,
// not to a copy of it.
func (iv *InjectValidator) getContainer(pod *v1.PodSpec, name string, isInit bool) *v1.Container {
	containers := pod.Containers
	if isInit {
		containers = pod.InitContainers
	}
	// Index rather than range by value: v1.Container is a large struct, and
	// taking the address of a range variable yields a per-iteration copy
	// (flagged by linters as implicit memory aliasing).
	for i := range containers {
		if containers[i].Name == name {
			return &containers[i]
		}
	}
	return nil
}

// validateEnvVar finds the first environment variable called envName on the
// container and compares its literal Value with expectedValue. It returns
// nil on a match, and otherwise an error describing either the mismatch or
// the fact that no such variable exists.
func (iv *InjectValidator) validateEnvVar(container *v1.Container, envName, expectedValue string) error {
	for idx := range container.Env {
		entry := container.Env[idx]
		if entry.Name != envName {
			continue
		}
		// Only the first variable with this name is considered.
		if entry.Value != expectedValue {
			return fmt.Errorf("env: %s, expected: %s, actual %s", envName, expectedValue, entry.Value)
		}
		return nil
	}
	return fmt.Errorf("cannot find env: %s", envName)
}

// validatePort finds the first container port called portName and compares
// its ContainerPort with expectedValue (converted to int32). It returns nil
// on a match, and otherwise an error describing either the mismatch or the
// fact that no such port exists.
func (iv *InjectValidator) validatePort(container *v1.Container, portName string, expectedValue int) error {
	want := int32(expectedValue)
	for idx := range container.Ports {
		candidate := container.Ports[idx]
		if candidate.Name != portName {
			continue
		}
		// Only the first port with this name is considered.
		if candidate.ContainerPort != want {
			return fmt.Errorf("port: %s, expected: %d, actual %d", portName, expectedValue, candidate.ContainerPort)
		}
		return nil
	}
	return fmt.Errorf("cannot find port: %s", portName)
}

// validateDebugContainer reports an error when EnableDebug is set but the
// pod has no regular container named debugContainerName. Nothing is checked
// when EnableDebug is false.
func (iv *InjectValidator) validateDebugContainer(pod *v1.PodSpec) error {
	if !iv.EnableDebug {
		return nil
	}
	if debugContainer := iv.getContainer(pod, debugContainerName, false); debugContainer == nil {
		return fmt.Errorf("container %s missing", debugContainerName)
	}
	return nil
}

// validateProxyContainer checks the pod's proxy sidecar against every
// setting of the validator that has a non-zero value: listen addresses and
// named ports, probe ports, tap/profile switches, pull policy, resource
// requests and limits, image and version, logging, run-as UID, port lists,
// trace collector settings, connect timeouts and the pre-stop hook.
//
// It returns an error when the proxy container is missing, and otherwise the
// first mismatch found, or nil when all configured expectations hold.
// Missing optional sub-objects (probes, security context, lifecycle hook)
// are reported as errors rather than dereferenced.
func (iv *InjectValidator) validateProxyContainer(pod *v1.PodSpec) error {
	proxyContainer := iv.getContainer(pod, proxyContainerName, false)
	if proxyContainer == nil {
		return fmt.Errorf("container %s missing", proxyContainerName)
	}

	if iv.AdminPort != 0 {
		if err := iv.validateEnvVar(proxyContainer, "LINKERD2_PROXY_ADMIN_LISTEN_ADDR", fmt.Sprintf("0.0.0.0:%d", iv.AdminPort)); err != nil {
			return err
		}
		// Both probes are expected to target the admin port.
		if err := iv.validateProbePort("livenessProbe", proxyContainer.LivenessProbe, iv.AdminPort); err != nil {
			return err
		}
		if err := iv.validateProbePort("readinessProbe", proxyContainer.ReadinessProbe, iv.AdminPort); err != nil {
			return err
		}
		if err := iv.validatePort(proxyContainer, "linkerd-admin", iv.AdminPort); err != nil {
			return err
		}
	}

	if iv.ControlPort != 0 {
		if err := iv.validateEnvVar(proxyContainer, "LINKERD2_PROXY_CONTROL_LISTEN_ADDR", fmt.Sprintf("0.0.0.0:%d", iv.ControlPort)); err != nil {
			return err
		}
	}

	if iv.DisableTap {
		if err := iv.validateEnvVar(proxyContainer, "LINKERD2_PROXY_TAP_DISABLED", enabled); err != nil {
			return err
		}
	}

	if iv.EnableExternalProfiles {
		if err := iv.validateEnvVar(proxyContainer, "LINKERD2_PROXY_DESTINATION_PROFILE_SUFFIXES", "."); err != nil {
			return err
		}
	}

	if iv.ImagePullPolicy != "" {
		if actual := string(proxyContainer.ImagePullPolicy); actual != iv.ImagePullPolicy {
			return fmt.Errorf("pullPolicy: expected: %s, actual %s", iv.ImagePullPolicy, actual)
		}
	}

	if iv.InboundPort != 0 {
		// The probes target the admin port and are verified in the AdminPort
		// branch above; re-checking them here against AdminPort would fail
		// whenever only InboundPort is configured.
		if err := iv.validateEnvVar(proxyContainer, "LINKERD2_PROXY_INBOUND_LISTEN_ADDR", fmt.Sprintf("0.0.0.0:%d", iv.InboundPort)); err != nil {
			return err
		}
		if err := iv.validatePort(proxyContainer, "linkerd-proxy", iv.InboundPort); err != nil {
			return err
		}
	}

	if iv.OutboundPort != 0 {
		if err := iv.validateEnvVar(proxyContainer, "LINKERD2_PROXY_OUTBOUND_LISTEN_ADDR", fmt.Sprintf("127.0.0.1:%d", iv.OutboundPort)); err != nil {
			return err
		}
	}

	if iv.CPULimit != "" {
		if err := iv.validateResource("CpuLimit", proxyContainer.Resources.Limits, v1.ResourceCPU, iv.CPULimit); err != nil {
			return err
		}
	}

	if iv.CPURequest != "" {
		if err := iv.validateResource("CpuRequest", proxyContainer.Resources.Requests, v1.ResourceCPU, iv.CPURequest); err != nil {
			return err
		}
	}

	if iv.MemoryLimit != "" {
		if err := iv.validateResource("MemLimit", proxyContainer.Resources.Limits, v1.ResourceMemory, iv.MemoryLimit); err != nil {
			return err
		}
	}

	if iv.MemoryRequest != "" {
		if err := iv.validateResource("MemRequest", proxyContainer.Resources.Requests, v1.ResourceMemory, iv.MemoryRequest); err != nil {
			return err
		}
	}

	if iv.Image != "" || iv.Version != "" {
		// The image reference is expected in exactly "name:version" form.
		image := strings.Split(proxyContainer.Image, ":")
		if len(image) != 2 {
			return fmt.Errorf("invalid proxy container image string: %s", proxyContainer.Image)
		}

		if iv.Image != "" && image[0] != iv.Image {
			return fmt.Errorf("proxyImage: expected %s, actual %s", iv.Image, image[0])
		}

		if iv.Version != "" && image[1] != iv.Version {
			return fmt.Errorf("proxyImageVersion: expected %s, actual %s", iv.Version, image[1])
		}
	}

	if iv.LogLevel != "" {
		if err := iv.validateEnvVar(proxyContainer, "LINKERD2_PROXY_LOG", iv.LogLevel); err != nil {
			return err
		}
	}

	if iv.LogFormat != "" {
		if err := iv.validateEnvVar(proxyContainer, "LINKERD2_PROXY_LOG_FORMAT", iv.LogFormat); err != nil {
			return err
		}
	}

	if iv.UID != 0 {
		securityContext := proxyContainer.SecurityContext
		if securityContext == nil || securityContext.RunAsUser == nil {
			return fmt.Errorf("no RunAsUser specified")
		}
		if *securityContext.RunAsUser != int64(iv.UID) {
			return fmt.Errorf("runAsUser: expected %d, actual %d", iv.UID, *securityContext.RunAsUser)
		}
	}

	if iv.RequireIdentityOnPorts != "" {
		if err := iv.validateEnvVar(proxyContainer, "LINKERD2_PROXY_INBOUND_PORTS_REQUIRE_IDENTITY", iv.RequireIdentityOnPorts); err != nil {
			return err
		}
	}

	if iv.OpaquePorts != "" {
		if err := iv.validateEnvVar(proxyContainer, "LINKERD2_PROXY_INBOUND_PORTS_DISABLE_PROTOCOL_DETECTION", iv.OpaquePorts); err != nil {
			return err
		}
	}

	if iv.TraceCollector != "" {
		if err := iv.validateEnvVar(proxyContainer, "LINKERD2_PROXY_TRACE_COLLECTOR_SVC_ADDR", iv.TraceCollector); err != nil {
			return err
		}
	}

	if iv.TraceCollectorSvcAccount != "" {
		expectedAccount := fmt.Sprintf("%s.tracing.serviceaccount.identity.$(_l5d_ns).$(_l5d_trustdomain)", iv.TraceCollectorSvcAccount)
		if err := iv.validateEnvVar(proxyContainer, "LINKERD2_PROXY_TRACE_COLLECTOR_SVC_NAME", expectedAccount); err != nil {
			return err
		}
	}

	if iv.OutboundConnectTimeout != "" {
		if err := iv.validateEnvVar(proxyContainer, "LINKERD2_PROXY_OUTBOUND_CONNECT_TIMEOUT", iv.OutboundConnectTimeout); err != nil {
			return err
		}
	}

	// Gate on the inbound setting itself, not on the outbound one.
	if iv.InboundConnectTimeout != "" {
		if err := iv.validateEnvVar(proxyContainer, "LINKERD2_PROXY_INBOUND_CONNECT_TIMEOUT", iv.InboundConnectTimeout); err != nil {
			return err
		}
	}

	if iv.WaitBeforeExitSeconds != 0 {
		expectedCmd := fmt.Sprintf("/bin/bash,-c,sleep %d", iv.WaitBeforeExitSeconds)
		lifecycle := proxyContainer.Lifecycle
		if lifecycle == nil || lifecycle.PreStop == nil || lifecycle.PreStop.Exec == nil {
			return fmt.Errorf("preStopHook: expected %s, but none", expectedCmd)
		}
		// The command is compared in comma-joined form.
		actual := strings.Join(lifecycle.PreStop.Exec.Command, ",")
		if actual != expectedCmd {
			return fmt.Errorf("preStopHook: expected %s, actual %s", expectedCmd, actual)
		}
	}

	return nil
}

// validateProbePort checks that probe has an HTTPGet handler whose integer
// port equals expected. label names the probe in error messages.
func (iv *InjectValidator) validateProbePort(label string, probe *v1.Probe, expected int) error {
	if probe == nil || probe.HTTPGet == nil {
		return fmt.Errorf("%s: expected: %d, but no HTTPGet probe is configured", label, expected)
	}
	if actual := probe.HTTPGet.Port.IntVal; actual != int32(expected) {
		return fmt.Errorf("%s: expected: %d, actual %d", label, expected, actual)
	}
	return nil
}

// validateResource checks that the resource list holds an entry for name
// equal to the quantity string expected. label names the setting in error
// messages. An unparsable expected value is returned as an error.
func (iv *InjectValidator) validateResource(label string, resources v1.ResourceList, name v1.ResourceName, expected string) error {
	want, err := resource.ParseQuantity(expected)
	if err != nil {
		return fmt.Errorf("%s: invalid expected quantity %q: %w", label, expected, err)
	}
	// Look the entry up directly: the Cpu()/Memory() accessors never return
	// nil, so they cannot distinguish "absent" from "zero".
	got, ok := resources[name]
	if !ok {
		return fmt.Errorf("%s: expected %v, but none", label, &want)
	}
	if !got.Equal(want) {
		return fmt.Errorf("%s: expected %v, actual %v", label, &want, &got)
	}
	return nil
}

// validateInitContainer checks the pod's init container against the
// validator's settings: image and version, the inbound and outbound proxy
// ports, the proxy UID, and the lists of ports to ignore. Only settings with
// a non-zero value are checked.
//
// It returns nil immediately when NoInitContainer is set, an error when the
// init container is missing, and otherwise the first mismatch found.
func (iv *InjectValidator) validateInitContainer(pod *v1.PodSpec) error {
	if iv.NoInitContainer {
		return nil
	}
	initContainer := iv.getContainer(pod, initContainerName, true)
	if initContainer == nil {
		return fmt.Errorf("container %s missing", initContainerName)
	}

	if iv.InitImage != "" || iv.InitImageVersion != "" {
		// The image reference is expected in exactly "name:version" form.
		image := strings.Split(initContainer.Image, ":")
		if len(image) != 2 {
			return fmt.Errorf("invalid proxy init image string: %s", initContainer.Image)
		}

		if iv.InitImage != "" && image[0] != iv.InitImage {
			return fmt.Errorf("proxyInitImage: expected %s, actual %s", iv.InitImage, image[0])
		}

		if iv.InitImageVersion != "" && image[1] != iv.InitImageVersion {
			return fmt.Errorf("proxyInitImageVersion: expected %s, actual %s", iv.InitImageVersion, image[1])
		}
	}

	if iv.InboundPort != 0 {
		if err := iv.validateArg(initContainer, "--incoming-proxy-port", strconv.Itoa(iv.InboundPort)); err != nil {
			return err
		}
	}

	// Each argument is gated on the field it is compared against, so an
	// unset UID or outbound port is never checked as "0".
	if iv.UID != 0 {
		if err := iv.validateArg(initContainer, "--proxy-uid", strconv.Itoa(iv.UID)); err != nil {
			return err
		}
	}

	if iv.OutboundPort != 0 {
		if err := iv.validateArg(initContainer, "--outgoing-proxy-port", strconv.Itoa(iv.OutboundPort)); err != nil {
			return err
		}
	}

	if iv.SkipInboundPorts != "" {
		// The expected value is the control and admin ports followed by the
		// configured list.
		expectedPorts := fmt.Sprintf("%d,%d,%s", iv.ControlPort, iv.AdminPort, iv.SkipInboundPorts)
		if err := iv.validateArg(initContainer, "--inbound-ports-to-ignore", expectedPorts); err != nil {
			return err
		}
	}

	if iv.SkipOutboundPorts != "" {
		if err := iv.validateArg(initContainer, "--outbound-ports-to-ignore", iv.SkipOutboundPorts); err != nil {
			return err
		}
	}

	return nil
}

// validateArg locates the first occurrence of argName in the container's
// Args and compares the element that follows it with expectedValue. It
// returns an error when the argument is absent, when it is the last element
// and so has no value, or when the value differs.
func (iv *InjectValidator) validateArg(container *v1.Container, argName, expectedValue string) error {
	args := container.Args
	for pos := 0; pos < len(args); pos++ {
		if args[pos] != argName {
			continue
		}
		valuePos := pos + 1
		if valuePos >= len(args) {
			return fmt.Errorf("No value for arg %s", argName)
		}
		if actual := args[valuePos]; actual != expectedValue {
			return fmt.Errorf("container arg %s expected: %s, actual %s", argName, expectedValue, actual)
		}
		return nil
	}

	return fmt.Errorf("Could not find arg: %s", argName)
}

// ValidatePod verifies that the injector configured the pod spec according
// to the validator's settings. The proxy container is checked first, then
// the init container, then the debug container; the first failing check's
// error is returned, or nil when all of them pass.
func (iv *InjectValidator) ValidatePod(pod *v1.PodSpec) error {
	checks := []func(*v1.PodSpec) error{
		iv.validateProxyContainer,
		iv.validateInitContainer,
		iv.validateDebugContainer,
	}

	for _, check := range checks {
		if err := check(pod); err != nil {
			return err
		}
	}

	return nil
}

// GetFlagsAndAnnotations expresses the validator's settings in two forms:
// a list of injector command-line flags and a map of pod annotations.
//
// Only settings with a non-zero value are emitted. Settings for which this
// function has no flag form (auto-inject, log format, opaque ports, trace
// collector address and service account, connect timeouts) appear only in
// the annotations. The returned flag slice is nil when no flag applies; the
// annotation map is always non-nil.
func (iv *InjectValidator) GetFlagsAndAnnotations() ([]string, map[string]string) {
	annotations := make(map[string]string)
	var flags []string

	if iv.AutoInject {
		annotations[k8s.ProxyInjectAnnotation] = k8s.ProxyInjectEnabled
	}

	if iv.AdminPort != 0 {
		annotations[k8s.ProxyAdminPortAnnotation] = strconv.Itoa(iv.AdminPort)
		flags = append(flags, fmt.Sprintf("--admin-port=%d", iv.AdminPort))
	}

	if iv.ControlPort != 0 {
		annotations[k8s.ProxyControlPortAnnotation] = strconv.Itoa(iv.ControlPort)
		flags = append(flags, fmt.Sprintf("--control-port=%d", iv.ControlPort))
	}

	if iv.DisableIdentity {
		// NOTE(review): the key used here before, k8s.IdentityModeDisabled,
		// names an identity mode rather than an annotation; confirm that
		// ProxyDisableIdentityAnnotation is the intended key.
		annotations[k8s.ProxyDisableIdentityAnnotation] = enabled
		flags = append(flags, "--disable-identity")
	}

	if iv.DisableTap {
		annotations[k8s.ProxyDisableTapAnnotation] = enabled
		flags = append(flags, "--disable-tap")
	}

	if iv.EnableDebug {
		annotations[k8s.ProxyEnableDebugAnnotation] = enabled
		flags = append(flags, "--enable-debug-sidecar")
	}

	if iv.EnableExternalProfiles {
		annotations[k8s.ProxyEnableExternalProfilesAnnotation] = enabled
		flags = append(flags, "--enable-external-profiles")
	}

	if iv.ImagePullPolicy != "" {
		annotations[k8s.ProxyImagePullPolicyAnnotation] = iv.ImagePullPolicy
		flags = append(flags, fmt.Sprintf("--image-pull-policy=%s", iv.ImagePullPolicy))
	}

	if iv.InboundPort != 0 {
		annotations[k8s.ProxyInboundPortAnnotation] = strconv.Itoa(iv.InboundPort)
		flags = append(flags, fmt.Sprintf("--inbound-port=%d", iv.InboundPort))
	}

	if iv.InitImage != "" {
		annotations[k8s.ProxyInitImageAnnotation] = iv.InitImage
		flags = append(flags, fmt.Sprintf("--init-image=%s", iv.InitImage))
	}

	if iv.InitImageVersion != "" {
		annotations[k8s.ProxyInitImageVersionAnnotation] = iv.InitImageVersion
		flags = append(flags, fmt.Sprintf("--init-image-version=%s", iv.InitImageVersion))
	}

	if iv.OutboundPort != 0 {
		annotations[k8s.ProxyOutboundPortAnnotation] = strconv.Itoa(iv.OutboundPort)
		flags = append(flags, fmt.Sprintf("--outbound-port=%d", iv.OutboundPort))
	}

	if iv.CPULimit != "" {
		annotations[k8s.ProxyCPULimitAnnotation] = iv.CPULimit
		flags = append(flags, fmt.Sprintf("--proxy-cpu-limit=%s", iv.CPULimit))
	}

	if iv.CPURequest != "" {
		annotations[k8s.ProxyCPURequestAnnotation] = iv.CPURequest
		flags = append(flags, fmt.Sprintf("--proxy-cpu-request=%s", iv.CPURequest))
	}

	if iv.MemoryLimit != "" {
		annotations[k8s.ProxyMemoryLimitAnnotation] = iv.MemoryLimit
		flags = append(flags, fmt.Sprintf("--proxy-memory-limit=%s", iv.MemoryLimit))
	}

	if iv.MemoryRequest != "" {
		annotations[k8s.ProxyMemoryRequestAnnotation] = iv.MemoryRequest
		flags = append(flags, fmt.Sprintf("--proxy-memory-request=%s", iv.MemoryRequest))
	}

	if iv.Image != "" {
		annotations[k8s.ProxyImageAnnotation] = iv.Image
		flags = append(flags, fmt.Sprintf("--proxy-image=%s", iv.Image))
	}

	if iv.LogLevel != "" {
		annotations[k8s.ProxyLogLevelAnnotation] = iv.LogLevel
		flags = append(flags, fmt.Sprintf("--proxy-log-level=%s", iv.LogLevel))
	}

	if iv.LogFormat != "" {
		annotations[k8s.ProxyLogFormatAnnotation] = iv.LogFormat
	}

	if iv.UID != 0 {
		annotations[k8s.ProxyUIDAnnotation] = strconv.Itoa(iv.UID)
		flags = append(flags, fmt.Sprintf("--proxy-uid=%d", iv.UID))
	}

	if iv.Version != "" {
		annotations[k8s.ProxyVersionOverrideAnnotation] = iv.Version
		flags = append(flags, fmt.Sprintf("--proxy-version=%s", iv.Version))
	}

	if iv.RequireIdentityOnPorts != "" {
		annotations[k8s.ProxyRequireIdentityOnInboundPortsAnnotation] = iv.RequireIdentityOnPorts
		// No space before '=': the flag name and value must form one token.
		flags = append(flags, fmt.Sprintf("--require-identity-on-inbound-ports=%s", iv.RequireIdentityOnPorts))
	}

	if iv.SkipInboundPorts != "" {
		annotations[k8s.ProxyIgnoreInboundPortsAnnotation] = iv.SkipInboundPorts
		flags = append(flags, fmt.Sprintf("--skip-inbound-ports=%s", iv.SkipInboundPorts))
	}

	if iv.OpaquePorts != "" {
		annotations[k8s.ProxyOpaquePortsAnnotation] = iv.OpaquePorts
	}

	if iv.SkipOutboundPorts != "" {
		annotations[k8s.ProxyIgnoreOutboundPortsAnnotation] = iv.SkipOutboundPorts
		flags = append(flags, fmt.Sprintf("--skip-outbound-ports=%s", iv.SkipOutboundPorts))
	}

	if iv.TraceCollector != "" {
		annotations[k8s.ProxyTraceCollectorSvcAddrAnnotation] = iv.TraceCollector
	}

	if iv.TraceCollectorSvcAccount != "" {
		annotations[k8s.ProxyTraceCollectorSvcAccountAnnotation] = iv.TraceCollectorSvcAccount
	}

	if iv.OutboundConnectTimeout != "" {
		annotations[k8s.ProxyOutboundConnectTimeout] = iv.OutboundConnectTimeout
	}

	if iv.InboundConnectTimeout != "" {
		annotations[k8s.ProxyInboundConnectTimeout] = iv.InboundConnectTimeout
	}

	if iv.WaitBeforeExitSeconds != 0 {
		annotations[k8s.ProxyWaitBeforeExitSecondsAnnotation] = strconv.Itoa(iv.WaitBeforeExitSeconds)
		flags = append(flags, fmt.Sprintf("--wait-before-exit-seconds=%d", iv.WaitBeforeExitSeconds))
	}

	return flags, annotations
}
